import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import pyautogui
import pytesseract

# Tell pytesseract where the Tesseract executable lives. This is a hard-coded
# Windows path and is applied at import time; change it if Tesseract is
# installed somewhere else on the machine running this script.
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'

def merge_horizontal_elements(elements, max_horizontal_gap=10, max_vertical_difference=5):
    """Merge elements that sit side by side on the same row.

    Elements are visited in (y, x) order of their bounding boxes. Each one is
    compared with the most recently kept element using
    ``is_horizontal_neighbor``; when it qualifies, the kept element's bbox
    grows to the union of both boxes, otherwise the element is kept as a new
    entry.

    The caller's data is not modified: the input is sorted into a new list
    and every kept element is a shallow copy of the input dict.

    Args:
        elements: Iterable of dicts, each with a ``'bbox'`` entry of the form
            ``(x, y, width, height)``.
        max_horizontal_gap: Largest allowed gap between the right edge of the
            kept element and the left edge of the next one.
        max_vertical_difference: Largest allowed difference between the two
            elements' top ``y`` coordinates.

    Returns:
        A new list of element dicts in (y, x) order. A merged entry keeps all
        keys of the first element of its run; only ``'bbox'`` is replaced, so
        other values (for example ``'area'``) are not recomputed.
    """
    merged = []
    # sorted() instead of list.sort() so the caller's list order is preserved.
    ordered = sorted(elements, key=lambda e: (e['bbox'][1], e['bbox'][0]))

    for element in ordered:
        if not merged or not is_horizontal_neighbor(
                merged[-1], element, max_horizontal_gap, max_vertical_difference):
            # Shallow copy so a later bbox update never touches the input dict.
            merged.append(dict(element))
            continue

        last_x, last_y, last_w, last_h = merged[-1]['bbox']
        cur_x, cur_y, cur_w, cur_h = element['bbox']

        # Union of the two rectangles, converted back to (x, y, w, h).
        left = min(last_x, cur_x)
        top = min(last_y, cur_y)
        right = max(last_x + last_w, cur_x + cur_w)
        bottom = max(last_y + last_h, cur_y + cur_h)
        merged[-1]['bbox'] = (left, top, right - left, bottom - top)

    return merged

def is_horizontal_neighbor(e1, e2, max_horizontal_gap, max_vertical_difference):
    """Return True when ``e2`` starts just to the right of ``e1`` on the same row.

    Both arguments are dicts with a ``'bbox'`` of ``(x, y, width, height)``.
    The tops of the two boxes must differ by at most
    ``max_vertical_difference``, and the gap from ``e1``'s right edge to
    ``e2``'s left edge must lie in ``[0, max_horizontal_gap]`` (a negative
    gap, i.e. horizontal overlap, does not count as a neighbor).
    """
    left_x, left_y, left_w, _ = e1['bbox']
    right_x, right_y, _, _ = e2['bbox']

    # Reject early when the two boxes are not on (roughly) the same row.
    if not abs(left_y - right_y) <= max_vertical_difference:
        return False

    gap = right_x - (left_x + left_w)
    return 0 <= gap <= max_horizontal_gap

def is_significant_overlap(e1, e2, threshold=0.7):
    """Return True when two elements overlap by more than ``threshold``.

    The overlap ratio is the intersection area divided by the area of the
    smaller of the two bounding boxes.

    Args:
        e1: Dict with a ``'bbox'`` of ``(x, y, width, height)``.
        e2: Dict with a ``'bbox'`` of ``(x, y, width, height)``.
        threshold: Ratio that must be strictly exceeded.

    Returns:
        True if the ratio is greater than ``threshold``; False if the boxes
        are disjoint, or if either box has no area (the ratio is undefined).
    """
    x1, y1, w1, h1 = e1['bbox']
    x2, y2, w2, h2 = e2['bbox']

    # Calculate the intersection rectangle
    x_left = max(x1, x2)
    y_top = max(y1, y2)
    x_right = min(x1 + w1, x2 + w2)
    y_bottom = min(y1 + h1, y2 + h2)

    if x_right < x_left or y_bottom < y_top:
        return False

    smaller_area = min(w1 * h1, w2 * h2)
    if smaller_area <= 0:
        # A zero-width/height box can still touch the other box and get past
        # the disjoint check above; dividing by its area would raise
        # ZeroDivisionError, so treat it as "no significant overlap".
        return False

    intersection_area = (x_right - x_left) * (y_bottom - y_top)
    return intersection_area / smaller_area > threshold

def detect_elements(screenshot_path):
    """Detect candidate UI elements in a screenshot and save an annotated copy.

    Steps:
      1. Run Canny edge detection at three threshold pairs and OR the results.
      2. Dilate the edges and extract contours (with hierarchy).
      3. Keep contours whose area is between 50 px and half the image area
         and whose bounding box is at least 5x5, then keep the 100 largest.
      4. Merge horizontally adjacent boxes with ``merge_horizontal_elements``
         and number the results starting at 1.
      5. Draw each box and its number on a copy of the image and save it.

    Args:
        screenshot_path: Path of the image file to analyse.

    Returns:
        A tuple ``(annotated_path, merged_elements)`` where ``annotated_path``
        is ``'annotated_screenshot.png'`` (written relative to the current
        working directory) and ``merged_elements`` is the list of element
        dicts with keys ``'id'``, ``'bbox'`` (``(x, y, w, h)``), ``'area'``
        and ``'hierarchy'``.

    Note:
        The label font is loaded with ``ImageFont.truetype("arialbd.ttf", 28)``,
        so that font file must be resolvable by Pillow on the host.
    """
    # Read the image.
    # NOTE: cv2.imread returns None for a missing/unreadable file, in which
    # case the cvtColor call below is what raises.
    img = cv2.imread(screenshot_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Apply multiple edge detection methods with different parameters
    edges1 = cv2.Canny(gray, 30, 100)
    edges2 = cv2.Canny(gray, 60, 200)
    edges3 = cv2.Canny(gray, 90, 250)
    edges = cv2.bitwise_or(cv2.bitwise_or(edges1, edges2), edges3)

    # Dilate edges to connect nearby elements
    kernel = np.ones((3, 3), np.uint8)
    dilated = cv2.dilate(edges, kernel, iterations=2)

    # Find contours with hierarchy
    contours, hierarchy = cv2.findContours(dilated, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    # Filter and process contours
    elements = []
    min_area = 50  # Reduced minimum area
    max_area = img.shape[0] * img.shape[1] * 0.5  # Maximum area (50% of image)

    for i, contour in enumerate(contours):
        area = cv2.contourArea(contour)
        if min_area < area < max_area:
            x, y, w, h = cv2.boundingRect(contour)
            if w >= 5 and h >= 5:  # Allow smaller elements
                elements.append({
                    'id': len(elements) + 1,
                    'bbox': (x, y, w, h),
                    'area': area,
                    'hierarchy': hierarchy[0][i],  # hierarchy row for this contour
                })

    # Sort elements by area (largest first) and limit to top 100
    elements.sort(key=lambda e: e['area'], reverse=True)
    elements = elements[:100]

    # Merge horizontal elements (like parts of the URL bar)
    merged_elements = merge_horizontal_elements(elements)

    # Reassign IDs so they are contiguous after merging
    for i, element in enumerate(merged_elements):
        element['id'] = i + 1

    # Convert to PIL Image for drawing. The draw context is opened in "RGBA"
    # mode so that fills carrying an alpha value are blended into the RGB
    # image; with the default mode the alpha component is ignored and the
    # label background would be painted fully opaque.
    pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(pil_img, "RGBA")
    font = ImageFont.truetype("arialbd.ttf", 28)  # Use a single, larger font size for all labels

    for element in merged_elements:
        x, y, w, h = element['bbox']
        draw.rectangle([x, y, x + w, y + h], outline="red", width=2)

        # Determine label position
        if w <= 100 or h <= 40:  # For small elements, place label outside
            label_x = x + w + 5
            label_y = y + h // 2
            anchor = 'lm'  # Left-middle alignment
        else:
            label_x = x + w - 5
            label_y = y + 5
            anchor = 'ra'  # Right-top alignment

        label_text = str(element['id'])

        # Draw a semi-transparent background for the label. textbbox already
        # reports where the text will be rendered for this anchor, so use it
        # directly instead of re-deriving a box around the anchor point
        # (which does not line up with the rendered glyphs).
        text_bbox = draw.textbbox((label_x, label_y), label_text, font=font, anchor=anchor)
        draw.rectangle(text_bbox, fill=(255, 255, 255, 180))

        # Draw the label text
        draw.text((label_x, label_y), label_text, fill="red", font=font, anchor=anchor)

    # Save the annotated image
    annotated_path = 'annotated_screenshot.png'
    pil_img.save(annotated_path)

    return annotated_path, merged_elements

def capture_and_detect():
    """Grab the current screen, save it to disk and run element detection on it.

    The screenshot is written to ``'screenshot.png'`` in the current working
    directory and that path is handed to ``detect_elements``.

    Returns:
        Whatever ``detect_elements`` returns for the saved screenshot.
    """
    screenshot_path = 'screenshot.png'

    # Take the screenshot and persist it in one step.
    pyautogui.screenshot().save(screenshot_path)

    return detect_elements(screenshot_path)

def find_text_on_screen(screenshot_path, text):
    """Find the first OCR entry containing ``text`` and return its center.

    The image is converted to grayscale and passed to
    ``pytesseract.image_to_data``; each entry of the returned ``'text'`` list
    is checked for a case-insensitive substring match.

    Args:
        screenshot_path: Path of the image file to search.
        text: Text to look for. Matching is done per OCR entry, so a query
            is only found if it fits inside a single entry.

    Returns:
        ``(x, y)`` pixel coordinates of the center of the first matching
        entry, or ``None`` when there is no match, when ``text`` is empty or
        whitespace-only, or when any error occurs (the error is printed).
    """
    try:
        needle = text.lower()

        # The empty string is a substring of every entry, so it would
        # "match" whatever entry happens to come first. Treat an empty or
        # whitespace-only query as not found, before doing any OCR work.
        if not needle.strip():
            return None

        # Read the image
        img = cv2.imread(screenshot_path)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Perform text detection
        data = pytesseract.image_to_data(gray, output_type=pytesseract.Output.DICT)

        for i, word in enumerate(data['text']):
            if needle in word.lower():
                x = data['left'][i]
                y = data['top'][i]
                w = data['width'][i]
                h = data['height'][i]
                return (x + w // 2, y + h // 2)  # Return center of the word

        return None
    except Exception as e:
        # Best-effort lookup: report the problem and signal "not found".
        print(f"Error in text detection: {e}")
        return None